{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code for **\"Blind restoration of a JPEG-compressed image\"** and **\"Blind image denoising\"** figures. Select `fname` below to switch between the two.\n",
    "\n",
    "- To see overfitting, set `num_iter` (defined in the Setup cell) to a large value."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "#os.environ['CUDA_VISIBLE_DEVICES'] = '3'\n",
    "\n",
    "import numpy as np\n",
    "from models import *\n",
    "\n",
    "import torch\n",
    "import torch.optim\n",
    "\n",
    "# `skimage.measure.compare_psnr` was deprecated in scikit-image 0.16 and\n",
    "# removed in 0.18 in favour of `skimage.metrics.peak_signal_noise_ratio`.\n",
    "# Bind whichever one is installed to the name used by the cells below.\n",
    "try:\n",
    "    from skimage.metrics import peak_signal_noise_ratio as compare_psnr\n",
    "except ImportError:\n",
    "    from skimage.measure import compare_psnr\n",
    "from utils.denoising_utils import *\n",
    "\n",
    "torch.backends.cudnn.enabled = True\n",
    "torch.backends.cudnn.benchmark = True\n",
    "# Every tensor and module below is cast to this type, so a CUDA device is required.\n",
    "dtype = torch.cuda.FloatTensor\n",
    "\n",
    "imsize = -1   # forwarded to get_image(); presumably -1 keeps the native size -- TODO confirm\n",
    "PLOT = True   # show image grids while loading and optimizing\n",
    "sigma = 25    # std of the synthetic noise, before division by 255\n",
    "sigma_ = sigma/255.  # the value actually passed to get_noisy_image()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose the experiment: leave exactly one `fname` assignment uncommented.\n",
    "\n",
    "# JPEG-artifact removal (the image is used as-is; no ground truth exists)\n",
    "# fname = 'data/denoising/snail.jpg'\n",
    "\n",
    "# Denoising (synthetic noise is added to this clean image)\n",
    "fname = 'data/denoising/F16_GT.png'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Produces four globals used by the rest of the notebook:\n",
    "#   img_pil / img_np             -- reference image (PIL and numpy form)\n",
    "#   img_noisy_pil / img_noisy_np -- degraded image the network is fitted to\n",
    "# crop_image(..., d=32) presumably trims the image so each side is a\n",
    "# multiple of 32 -- TODO confirm in utils.\n",
    "if fname == 'data/denoising/snail.jpg':\n",
    "    # The JPEG file itself is the degraded observation.\n",
    "    img_noisy_pil = crop_image(get_image(fname, imsize)[0], d=32)\n",
    "    img_noisy_np = pil_to_np(img_noisy_pil)\n",
    "\n",
    "    # As we don't have ground truth, the reference simply aliases the observation.\n",
    "    img_pil = img_noisy_pil\n",
    "    img_np = img_noisy_np\n",
    "\n",
    "    if PLOT:\n",
    "        plot_image_grid([img_np], 4, 5)\n",
    "\n",
    "elif fname == 'data/denoising/F16_GT.png':\n",
    "    # Clean reference image; synthetic noise with std `sigma_` is added to it.\n",
    "    img_pil = crop_image(get_image(fname, imsize)[0], d=32)\n",
    "    img_np = pil_to_np(img_pil)\n",
    "\n",
    "    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)\n",
    "\n",
    "    if PLOT:\n",
    "        plot_image_grid([img_np, img_noisy_np], 4, 6)\n",
    "else:\n",
    "    # Fail loudly with the offending value instead of a bare `assert False`,\n",
    "    # which carries no message and is skipped when Python runs with -O.\n",
    "    raise ValueError('Unsupported fname: %r' % (fname,))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT = 'noise' # 'meshgrid'      -- kind of network input requested from get_noise()\n",
    "pad = 'reflection'              # padding mode handed to the network constructors\n",
    "OPT_OVER = 'net' # 'net,input'   -- what get_params() selects for optimization\n",
    "\n",
    "# Std of the Gaussian perturbation added to the network input on every iteration.\n",
    "reg_noise_std = 1./30. # set to 1./20. for sigma=50\n",
    "LR = 0.01\n",
    "\n",
    "OPTIMIZER = 'adam' # 'LBFGS'\n",
    "show_every = 100    # plot every this many iterations\n",
    "exp_weight = 0.99   # decay of the running average of network outputs\n",
    "\n",
    "if fname == 'data/denoising/snail.jpg':\n",
    "    num_iter = 2400\n",
    "    input_depth = 3\n",
    "    figsize = 5\n",
    "\n",
    "    net = skip(\n",
    "                input_depth, 3,\n",
    "                num_channels_down = [8, 16, 32, 64, 128],\n",
    "                num_channels_up   = [8, 16, 32, 64, 128],\n",
    "                num_channels_skip = [0, 0, 0, 4, 4],\n",
    "                upsample_mode='bilinear',\n",
    "                need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU')\n",
    "\n",
    "    net = net.type(dtype)\n",
    "\n",
    "elif fname == 'data/denoising/F16_GT.png':\n",
    "    num_iter = 3000\n",
    "    input_depth = 32\n",
    "    figsize = 4\n",
    "\n",
    "    net = get_net(input_depth, 'skip', pad,\n",
    "                  skip_n33d=128,\n",
    "                  skip_n33u=128,\n",
    "                  skip_n11=4,\n",
    "                  num_scales=5,\n",
    "                  upsample_mode='bilinear').type(dtype)\n",
    "\n",
    "else:\n",
    "    # Fail loudly with the offending value instead of a bare `assert False`,\n",
    "    # which carries no message and is skipped when Python runs with -O.\n",
    "    raise ValueError('Unsupported fname: %r' % (fname,))\n",
    "\n",
    "# The spatial size is given as (size[1], size[0]); for a PIL image `size` is\n",
    "# (width, height), so this is (height, width).\n",
    "net_input = get_noise(input_depth, INPUT, (img_pil.size[1], img_pil.size[0])).type(dtype).detach()\n",
    "\n",
    "# Compute number of parameters\n",
    "s = sum(p.numel() for p in net.parameters())\n",
    "print ('Number of params: %d' % s)\n",
    "\n",
    "# Loss\n",
    "mse = torch.nn.MSELoss().type(dtype)\n",
    "\n",
    "# Fitting target: the degraded image as a tensor of the working dtype.\n",
    "img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Clean copy of the network input; closure() perturbs it anew on each call.\n",
    "net_input_saved = net_input.detach().clone()\n",
    "# Buffer with the same shape as the input, refilled in place by noise.normal_().\n",
    "noise = net_input.detach().clone()\n",
    "out_avg = None        # running average of the network outputs (set on the first call)\n",
    "last_net = None       # CPU copies of the parameters from the last accepted iteration\n",
    "psrn_noisy_last = 0   # PSNR against the noisy image at that iteration\n",
    "\n",
    "i = 0\n",
    "def closure():\n",
    "    \"\"\"Do one forward/backward pass and return the loss tensor.\n",
    "\n",
    "    Passed to optimize(), which presumably calls it once per optimization step.\n",
    "    Besides accumulating gradients of the MSE between the network output and\n",
    "    the noisy image, it updates the running output average, prints PSNR\n",
    "    metrics, plots every `show_every` iterations, and restores the last saved\n",
    "    parameters when the PSNR against the noisy image falls more than 5 below\n",
    "    the value recorded at the last accepted iteration.\n",
    "\n",
    "    Returns:\n",
    "        The MSE loss tensor, or that loss multiplied by zero when the\n",
    "        parameters were rolled back.\n",
    "    \"\"\"\n",
    "    global i, out_avg, psrn_noisy_last, last_net, net_input\n",
    "\n",
    "    # Perturb the fixed input with fresh Gaussian noise.\n",
    "    if reg_noise_std > 0:\n",
    "        net_input = net_input_saved + (noise.normal_() * reg_noise_std)\n",
    "\n",
    "    out = net(net_input)\n",
    "    out_detached = out.detach()\n",
    "\n",
    "    # Smoothing\n",
    "    if out_avg is None:\n",
    "        out_avg = out_detached\n",
    "    else:\n",
    "        out_avg = out_avg * exp_weight + out_detached * (1 - exp_weight)\n",
    "\n",
    "    total_loss = mse(out, img_noisy_torch)\n",
    "    total_loss.backward()\n",
    "\n",
    "    # Copy the current output to host memory once and reuse it for both metrics.\n",
    "    out_host = out_detached.cpu().numpy()[0]\n",
    "    psrn_noisy = compare_psnr(img_noisy_np, out_host)\n",
    "    psrn_gt    = compare_psnr(img_np, out_host)\n",
    "    psrn_gt_sm = compare_psnr(img_np, out_avg.detach().cpu().numpy()[0])\n",
    "\n",
    "    # Note that we do not have GT for the \"snail\" example\n",
    "    # So 'PSRN_gt', 'PSNR_gt_sm' make no sense\n",
    "    print ('Iteration %05d    Loss %f   PSNR_noisy: %f   PSRN_gt: %f PSNR_gt_sm: %f' % (i, total_loss.item(), psrn_noisy, psrn_gt, psrn_gt_sm), '\\r', end='')\n",
    "    if PLOT and i % show_every == 0:\n",
    "        out_np = torch_to_np(out)\n",
    "        plot_image_grid([np.clip(out_np, 0, 1),\n",
    "                         np.clip(torch_to_np(out_avg), 0, 1)], factor=figsize, nrow=1)\n",
    "\n",
    "    # Backtracking\n",
    "    # A non-zero remainder is truthy, so this branch runs on every iteration\n",
    "    # that is NOT a multiple of show_every.\n",
    "    if i % show_every:\n",
    "        # last_net stays None until the first checkpoint is taken; the guard\n",
    "        # keeps an early drop from trying to restore from nothing.\n",
    "        if last_net is not None and psrn_noisy - psrn_noisy_last < -5:\n",
    "            print('Falling back to previous checkpoint.')\n",
    "\n",
    "            # copy_() handles the host-to-device transfer itself, so no\n",
    "            # intermediate .cuda() copy of the checkpoint is needed.\n",
    "            for saved_param, net_param in zip(last_net, net.parameters()):\n",
    "                net_param.data.copy_(saved_param)\n",
    "\n",
    "            # NOTE(review): gradients from backward() above are still set and\n",
    "            # `i` is not advanced on this path -- confirm that is intended.\n",
    "            return total_loss*0\n",
    "        else:\n",
    "            # .cpu() yields an independent copy here because the parameters\n",
    "            # were cast to a CUDA tensor type.\n",
    "            last_net = [x.detach().cpu() for x in net.parameters()]\n",
    "            psrn_noisy_last = psrn_noisy\n",
    "\n",
    "    i += 1\n",
    "\n",
    "    return total_loss\n",
    "\n",
    "p = get_params(OPT_OVER, net, net_input)\n",
    "optimize(OPTIMIZER, p, closure, LR, num_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final result next to the reference image.\n",
    "# Inference only: no_grad() avoids recording an autograd graph for this pass.\n",
    "# NOTE(review): `net_input` is the global last written by closure(), i.e. the\n",
    "# most recent perturbed input when reg_noise_std > 0, not `net_input_saved`.\n",
    "with torch.no_grad():\n",
    "    out_np = torch_to_np(net(net_input))\n",
    "q = plot_image_grid([np.clip(out_np, 0, 1), img_np], factor=13);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
